import json
import logging
import traceback
from abc import ABC, ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List

from .environment import Environment
from .runner import (
    ProfileLogs,
    ResultComparison,
    benchmark_server,
    compare_server_to_full,
)
from .specification import Specification


# Module-level logger, named after this module via `__name__`.
LOG: logging.Logger = logging.getLogger(__name__)


@dataclass
class Sample:
    """Flat record of named values produced by `RunnerResult.to_logger_sample`.

    The values are split by type into two mappings keyed by field name.
    NOTE(review): presumably consumed by a structured logging backend --
    confirm against the callers of `to_logger_sample`.
    """

    # Integer-valued fields (the result classes store check timings here).
    integers: Dict[str, int]
    # String-valued fields (the result classes store status/input/trace here).
    normals: Dict[str, str]


class RunnerResult(ABC):
    """Abstract base class for the result of running one `Specification`.

    Stores the specification that was run and declares the reporting
    interface that every concrete result type has to provide.
    """

    _input: Specification

    def __init__(self, input: Specification) -> None:
        """Remember the specification this result belongs to."""
        self._input = input

    @property
    def input(self) -> Specification:
        """The specification passed to the constructor (read-only)."""
        return self._input

    @abstractmethod
    def get_status(self) -> str:
        """Return a short string identifying the kind of result."""
        raise NotImplementedError()

    @abstractmethod
    def to_json(self, dont_show_discrepancy: bool) -> Dict[str, Any]:
        """Return a dictionary representation of this result.

        `dont_show_discrepancy` is a flag that subclasses holding a
        `ResultComparison` forward to its `to_json` method.
        """
        raise NotImplementedError()

    @abstractmethod
    def to_logger_sample(self) -> Sample:
        """Return this result as a `Sample` of string and integer fields."""
        raise NotImplementedError()


class ExceptionalRunnerResult(RunnerResult):
    """Result for a run that raised; carries the formatted traceback text."""

    _trace: str

    def __init__(self, input: Specification, trace: str) -> None:
        """Store the specification and the traceback string of the failure."""
        super().__init__(input)
        self._trace = trace

    def get_status(self) -> str:
        """Return the fixed status string ``"exception"``."""
        return "exception"

    def to_json(self, dont_show_discrepancy: bool) -> Dict[str, Any]:
        """Return the status and the traceback.

        `dont_show_discrepancy` is accepted for interface compatibility and
        is not used here.
        """
        status = self.get_status()
        return {"status": status, "trace": self._trace}

    def to_logger_sample(self) -> Sample:
        """Return a `Sample` with status, JSON-encoded input and traceback."""
        string_fields = {
            "status": self.get_status(),
            "input": json.dumps(self.input.to_json()),
            "exception": self._trace,
        }
        # There are no timings to report for a run that raised.
        return Sample(normals=string_fields, integers={})


class FinishedRunnerResult(RunnerResult, metaclass=ABCMeta):
    """Base for results of runs that completed and produced a comparison.

    Still abstract: `get_status` and `to_json` are left to subclasses.
    """

    _output: ResultComparison

    def __init__(self, input: Specification, output: ResultComparison) -> None:
        """Store the specification together with the comparison it produced."""
        super().__init__(input)
        self._output = output

    def to_logger_sample(self) -> Sample:
        """Return a `Sample` with status, JSON-encoded input and both timings."""
        # Timings are read before the string fields, in this order.
        timings = {
            "full_check_time": self._output.profile_logs.full_check_time(),
            "incremental_check_time": (
                self._output.profile_logs.total_incremental_check_time()
            ),
        }
        string_fields = {
            "status": self.get_status(),
            "input": json.dumps(self.input.to_json()),
        }
        return Sample(normals=string_fields, integers=timings)


class PassedRunnerResult(FinishedRunnerResult):
    """Finished result whose status is ``"pass"``."""

    def get_status(self) -> str:
        """Return the fixed status string ``"pass"``."""
        return "pass"

    def to_json(self, dont_show_discrepancy: bool) -> Dict[str, Any]:
        """Return the status and the comparison output.

        `dont_show_discrepancy` is forwarded to the comparison's `to_json`.
        """
        # The input specification is deliberately left out when the test passes.
        status = self.get_status()
        output_json = self._output.to_json(dont_show_discrepancy)
        return {"status": status, "output": output_json}


class FailedRunnerResult(FinishedRunnerResult):
    """Finished result whose status is ``"fail"``."""

    def get_status(self) -> str:
        """Return the fixed status string ``"fail"``."""
        return "fail"

    def to_json(self, dont_show_discrepancy: bool) -> Dict[str, Any]:
        """Return the status, the input specification and the comparison output.

        `dont_show_discrepancy` is forwarded to the comparison's `to_json`.
        """
        status = self.get_status()
        input_json = self.input.to_json()
        output_json = self._output.to_json(dont_show_discrepancy)
        return {"status": status, "input": input_json, "output": output_json}


class BenchmarkResult(RunnerResult):
    """Result of a benchmark run; wraps the `ProfileLogs` it produced."""

    _profile_logs: ProfileLogs

    def __init__(self, input: Specification, profile_logs: ProfileLogs) -> None:
        """Store the specification together with its profile logs."""
        super().__init__(input)
        self._profile_logs = profile_logs

    def get_status(self) -> str:
        """Return the fixed status string ``"benchmark"``."""
        return "benchmark"

    def to_json(self, dont_show_discrepancy: bool) -> Dict[str, Any]:
        """Return status, input, total incremental check time and the logs.

        `dont_show_discrepancy` is accepted for interface compatibility and
        is not used here.
        """
        logs = self._profile_logs
        return {
            "status": self.get_status(),
            "input": self.input.to_json(),
            "time": logs.total_incremental_check_time(),
            "profile_logs": logs.to_json(),
        }

    def to_logger_sample(self) -> Sample:
        """Return a `Sample` with status, JSON-encoded input and the timing."""
        elapsed = self._profile_logs.total_incremental_check_time()
        string_fields = {
            "status": self.get_status(),
            "input": json.dumps(self.input.to_json()),
        }
        return Sample(
            normals=string_fields,
            integers={"incremental_check_time": elapsed},
        )

    def profile_logs(self) -> ProfileLogs:
        """Return the profile logs given to the constructor."""
        return self._profile_logs


def run_single_test(environment: Environment, input: Specification) -> RunnerResult:
    """Run `compare_server_to_full` on one specification and wrap the outcome.

    Returns a `PassedRunnerResult` when the comparison reports no
    discrepancy and a `FailedRunnerResult` otherwise. Any `Exception` raised
    while logging the start message or running the comparison is captured as
    an `ExceptionalRunnerResult` holding the formatted traceback instead of
    being propagated.
    """
    result: RunnerResult
    try:
        LOG.info(f"Running test on state '{input.old_state}' vs '{input.new_state}'")
        comparison = compare_server_to_full(environment, input)
        # A missing discrepancy means the comparison passed.
        result_class = (
            PassedRunnerResult
            if comparison.discrepancy is None
            else FailedRunnerResult
        )
        result = result_class(input, comparison)
    except Exception:
        result = ExceptionalRunnerResult(input, traceback.format_exc())

    LOG.info(f"Test finished with status = {result.get_status()}")
    return result


def run_batch_test(
    environment: Environment, inputs: Iterable[Specification]
) -> List[RunnerResult]:
    """Run `run_single_test` on every specification, in iteration order.

    Returns one result per element of `inputs`, in the same order.
    """
    return [run_single_test(environment, specification) for specification in inputs]


def run_single_benchmark(
    environment: Environment, input: Specification
) -> RunnerResult:
    """Run `benchmark_server` on one specification and wrap the outcome.

    Returns a `BenchmarkResult` holding the value returned by
    `benchmark_server`. Any `Exception` raised while logging the start
    message or running the benchmark is captured as an
    `ExceptionalRunnerResult` holding the formatted traceback instead of
    being propagated.
    """
    result: RunnerResult
    try:
        LOG.info(
            f"Running benchmark on state '{input.old_state}' vs '{input.new_state}'"
        )
        profile_logs = benchmark_server(environment, input)
        result = BenchmarkResult(input, profile_logs)
    except Exception:
        # The traceback is kept in the result so the failure is still reported.
        result = ExceptionalRunnerResult(input, traceback.format_exc())

    # Say "Benchmark" here to match the start message above; the original
    # wording ("Test finished") was indistinguishable from run_single_test.
    LOG.info(f"Benchmark finished with status = {result.get_status()}")
    return result


def run_batch_benchmark(
    environment: Environment, inputs: Iterable[Specification]
) -> List[RunnerResult]:
    """Run `run_single_benchmark` on every specification, in iteration order.

    Returns one result per element of `inputs`, in the same order.
    """
    return [
        run_single_benchmark(environment, specification) for specification in inputs
    ]
